package com.garbagecode.resultbounds;

import org.apache.ibatis.executor.Executor;

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import com.garbagecode.resultbounds.bounds.ResultBounds;
import com.garbagecode.resultbounds.bounds.TotalCountBounds;
import com.garbagecode.resultbounds.utilty.MappedStatementModifier;
import com.garbagecode.resultbounds.utilty.TemporarySqlSource;

@Intercepts({
    @Signature(
        type = Executor.class,
        method = "query",
        args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class })
})
public class PaginationInterceptor extends AbstractPaginationInterceptor {

  /**
   * Intercepts the four-argument {@code Executor.query} call and, depending on
   * the row bounds in play, either runs a total-count query, runs a
   * dialect-specific paginated query, or lets the original invocation proceed.
   *
   * <p>Decision order:
   * <ol>
   *   <li>If the supplied row bounds is a {@link TotalCountBounds}, the call is
   *       treated as a request for the number of matching records and is
   *       delegated to {@link #getTotal}.</li>
   *   <li>Otherwise, a {@link ResultBounds} returned by
   *       {@code ResultBounds.getCurrentBounds()} (when non-null) replaces the
   *       supplied row bounds.</li>
   *   <li>If the effective bounds is a {@link ResultBounds}, {@link #setTotal}
   *       attaches a {@link TotalCountBounds} to it.</li>
   *   <li>If {@link #isNeedPagination} reports true, the statement's SQL is
   *       rewritten by the dialect and executed with {@code RowBounds.DEFAULT};
   *       otherwise the invocation proceeds untouched.</li>
   * </ol>
   *
   * @param invocation the intercepted {@code Executor.query} invocation
   * @return the result of whichever query was executed
   * @throws Throwable anything raised by the underlying query or invocation
   */
  @SuppressWarnings("rawtypes")
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    final Executor target = (Executor) invocation.getTarget();
    final Object[] queryArgs = invocation.getArgs();
    final MappedStatement statement = (MappedStatement) queryArgs[0];
    final Object queryParameter = queryArgs[1];
    final RowBounds requestedBounds = (RowBounds) queryArgs[2];

    // A TotalCountBounds marks this call as "count the matching records".
    if (requestedBounds instanceof TotalCountBounds) {
      return getTotal(target, queryArgs, statement, queryParameter);
    }

    // Bounds registered in the current context take precedence over the
    // bounds that were passed with the query.
    final ResultBounds contextBounds = ResultBounds.getCurrentBounds();
    final RowBounds effectiveBounds = (contextBounds == null) ? requestedBounds : contextBounds;

    final ResultBounds pageBounds = getResultBounds(effectiveBounds);

    // Only caller-visible ResultBounds instances receive a total-count handle;
    // a ResultBounds freshly wrapped around a plain RowBounds does not.
    if (effectiveBounds instanceof ResultBounds) {
      setTotal(pageBounds, statement, queryParameter, invocation);
    }

    // Nothing to paginate or order: run the query exactly as requested.
    if (!isNeedPagination(pageBounds)) {
      return invocation.proceed();
    }

    final BoundSql originalSql = statement.getBoundSql(queryParameter);

    final String pagedSql = getDialect().getPaginationSql(
        originalSql.getSql(),
        pageBounds.getOffset(),
        pageBounds.getLimit(),
        pageBounds.getOrder());

    final SqlSource pagedSource = new TemporarySqlSource(
        statement.getConfiguration(),
        pagedSql,
        originalSql.getParameterMappings(),
        originalSql.getParameterObject());

    // NOTE(review): null is passed as the third argument; presumably this
    // leaves the statement's result maps as they are - confirm in
    // MappedStatementModifier.
    final MappedStatement pagedStatement =
        new MappedStatementModifier().modify(statement, pagedSource, null);

    // Offset and limit now live in the SQL itself, so the default row bounds
    // are handed to the executor.
    final ResultHandler handler = (ResultHandler) queryArgs[3];
    return target.query(pagedStatement, queryParameter, RowBounds.DEFAULT, handler);
  }

  /**
   * Attaches a {@link TotalCountBounds}, built from the statement id and the
   * query parameter, to the given result bounds via
   * {@code ResultBounds.setTotal}.
   *
   * @param resultBounds    the bounds that receives the total-count handle
   * @param mappedStatement the statement whose id identifies the query
   * @param parameter       the parameter object of the query
   * @param invocation      the intercepted invocation (not used by this implementation)
   */
  protected void setTotal(ResultBounds resultBounds,
                          MappedStatement mappedStatement,
                          Object parameter,
                          Invocation invocation) {
    final TotalCountBounds countHandle =
        new TotalCountBounds(mappedStatement.getId(), parameter);
    resultBounds.setTotal(countHandle);
  }

  /**
   * Executes the total-count variant of the given statement.
   *
   * <p>The statement's SQL is turned into a count query by the dialect, wrapped
   * in a temporary SQL source that reuses the original parameter mappings and
   * parameter object, and run with {@code RowBounds.DEFAULT} using the result
   * maps supplied by {@code getResultMaps()}.
   *
   * @param executor        the executor that runs the count query
   * @param args            the original invocation arguments; index 3 is the result handler
   * @param mappedStatement the statement being counted
   * @param parameter       the parameter object of the query
   * @return whatever the executor returns for the count query
   * @throws Throwable anything raised while building or running the count query
   */
  @SuppressWarnings("rawtypes")
  protected Object getTotal(Executor executor,
                            Object[] args,
                            MappedStatement mappedStatement,
                            Object parameter) throws Throwable {
    final BoundSql originalSql = mappedStatement.getBoundSql(parameter);
    final String countSql = getDialect().getTotalCountSql(originalSql.getSql());

    final SqlSource countSource = new TemporarySqlSource(
        mappedStatement.getConfiguration(),
        countSql,
        originalSql.getParameterMappings(),
        originalSql.getParameterObject());

    final MappedStatement countStatement =
        new MappedStatementModifier().modify(mappedStatement, countSource, getResultMaps());

    final ResultHandler handler = (ResultHandler) args[3];
    return executor.query(countStatement, parameter, RowBounds.DEFAULT, handler);
  }

  /**
   * Converts a {@link RowBounds} into a {@link ResultBounds}.
   *
   * @param rowBounds the bounds to convert; may be {@code null}
   * @return {@code null} when {@code rowBounds} is {@code null}; the same
   *         instance when it already is a {@link ResultBounds}; otherwise a
   *         new {@link ResultBounds} constructed from it
   */
  protected ResultBounds getResultBounds(RowBounds rowBounds) {
    if (rowBounds == null) {
      return null;
    } else if (rowBounds instanceof ResultBounds) {
      return (ResultBounds) rowBounds;
    } else {
      return new ResultBounds(rowBounds);
    }
  }

  /**
   * Tells whether the query has to be rewritten for pagination or ordering.
   *
   * @param resultBounds the bounds to inspect; must not be {@code null}
   * @return {@code true} when the offset differs from
   *         {@code RowBounds.NO_ROW_OFFSET}, the limit differs from
   *         {@code RowBounds.NO_ROW_LIMIT}, or the order is a non-blank string
   * @throws IllegalArgumentException if {@code resultBounds} is {@code null}
   */
  protected boolean isNeedPagination(ResultBounds resultBounds) {
    if (resultBounds == null) {
      throw new IllegalArgumentException("resultBounds null");
    }

    if (resultBounds.getOffset() != RowBounds.NO_ROW_OFFSET) {
      return true;
    }
    if (resultBounds.getLimit() != RowBounds.NO_ROW_LIMIT) {
      return true;
    }
    // An order made only of whitespace does not count as an ordering request.
    return !(resultBounds.getOrder() == null
        || resultBounds.getOrder().trim().length() <= 0);
  }

}
